load("@aspect_bazel_lib//lib:tar.bzl", "mtree_spec", "tar")
load("@aspect_bazel_lib//lib:transitions.bzl", "platform_transition_filegroup")
load("@bazel_skylib//rules:build_test.bzl", "build_test")
load("@rules_oci//oci:defs.bzl", "oci_image", "oci_load", "oci_push")
load("@rules_zig//zig:defs.bzl", "zig_binary")

# Executable for this package: `main.zig` is the entry point and `llama.zig`
# is compiled in as an additional source. Publicly visible so that other
# packages can depend on it.
zig_binary(
    name = "llama",
    srcs = ["llama.zig"],
    main = "main.zig",
    visibility = ["//visibility:public"],
    deps = [
        "//async",
        "//stdx",
        "//zml",
        "@com_github_hejsil_clap//:clap",
    ],
)

# TODO(steeve): Re-enable with a small, tokenless huggingface model.
#
# zig_test(
#     name = "test-implementation",
#     srcs = ["llama.zig"],
#     args = [
#         "--weights=$(location @Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json)",
#         "--config=$(location @Meta-Llama-3.1-8B-Instruct//:config.json)",
#     ],
#     data = [
#         "@Meta-Llama-3.1-8B-Instruct//:config.json",
#         "@Meta-Llama-3.1-8B-Instruct//:model.safetensors.index.json",
#     ],
#     main = "test.zig",
#     deps = [
#         "//async",
#         "//stdx",
#         "//zml",
#     ],
# )

# native_test(
#     name = "test_tokenizer",
#     src = "//zml/tokenizer:main",
#     # Note: all Llama-3.x tokenizers are the same,
#     # but using the 3.2-1B version because downloading the tokenizer triggers downloading the model.
#     args = [
#         "--tokenizer=$(location @Meta-Llama-3.2-1B-Instruct//:tokenizer.json)",
#         """--prompt='Examples of titles:
# 📉 Stock Market Trends
# 🍪 Perfect Chocolate Chip Recipe
# Evolution of Music Streaming
# Remote Work Productivity Tips
# Artificial Intelligence in Healthcare
# 🎮 Video Game Development Insights
# '""",
#         # This corresponds to the encoding produced by HF tokenizers, with bos=False.
#         "--expected=41481,315,15671,512,9468,241,231,12937,8152,50730,198,9468,235,103,24118,39520,32013,26371,198,35212,3294,315,10948,45910,198,25732,5664,5761,1968,26788,198,9470,16895,22107,304,39435,198,9468,236,106,8519,4140,11050,73137,198",
#     ],
#     data = ["@Meta-Llama-3.2-1B-Instruct//:tokenizer.json"],
# )

# mtree specification generated from the `:llama` binary; consumed by the
# `:archive` tar target below through its `mtree` attribute.
mtree_spec(
    name = "mtree",
    srcs = [":llama"],
)

# zstd-compressed tarball of the `:llama` binary, with entries described by
# the `:mtree` specification.
tar(
    name = "archive",
    srcs = [":llama"],
    args = [
        "--options",
        # Comma-separated zstd settings: compression level 9, 16 threads.
        "zstd:compression-level=9,zstd:threads=16",
    ],
    compress = "zstd",
    mtree = ":mtree",
)

# Container image: the `:archive` tarball layered on top of the distroless
# cc/debian12 base. Restricted to Linux; use `:image` to get a build that is
# transitioned to the right platform.
oci_image(
    name = "image_",
    base = "@distroless_cc_debian12",
    # The entrypoint is the `llama` binary under this package's path
    # (presumably where the `:archive` layer places it -- confirm).
    entrypoint = ["/%s/llama" % package_name()],
    target_compatible_with = ["@platforms//os:linux"],
    tars = [":archive"],
)

# `:image_` built for the `//platforms:linux_amd64` target platform.
# Tagged "manual" so that wildcard patterns (e.g. `//...`) do not build it.
platform_transition_filegroup(
    name = "image",
    srcs = [":image_"],
    tags = ["manual"],
    target_platform = "//platforms:linux_amd64",
)

# Loads the `:image` container image locally, tagged `zmlai/llama:latest`.
# Tagged "manual": it only runs when requested explicitly.
oci_load(
    name = "load",
    image = ":image",
    repo_tags = ["zmlai/llama:latest"],
    tags = ["manual"],
)

# Pushes the `:image` container image to `index.docker.io/zmlai/llama` under
# the `latest` tag. Tagged "manual": it only runs when requested explicitly.
oci_push(
    name = "push",
    image = ":image",
    remote_tags = ["latest"],
    repository = "index.docker.io/zmlai/llama",
    tags = ["manual"],
)

# Test target that checks that the `:llama` binary builds.
build_test(
    name = "test",
    targets = [":llama"],
)
